import logging

import torch

from agents.interfaces import Learner

class CoordAgent:
    """Agent that drives a wrapped ``Learner`` by feeding it goals chosen by a policy.

    One call to :meth:`step` runs an inner loop of up to ``args.max_steps``
    calls to ``envs.step``; the policy is asked for a new goal on the first
    inner step and then every ``args.plan_interval`` inner steps (when that
    interval is positive). The mean masked reward of the inner loop is later
    handed to ``policy.evaluate`` by :meth:`learn`.
    """

    # Environment types whose observations are indexed with the keys
    # "observation" and (optionally) "state" instead of being used directly.
    _DICT_OBS_ENV_TYPES = ("multiworld", "maze", "gridworld")

    def __init__(self, envs, context, args, policy, rollouts, log_interval=10):
        """Store the collaborators and reset all counters.

        Args:
            envs: The wrapped learner; must be a ``Learner`` instance.
            context: Opaque value stored on the agent and forwarded by ``clone``.
            args: Configuration object; this class reads ``env_type``,
                ``state``, ``max_steps`` and ``plan_interval`` from it.
            policy: Object providing ``act``, ``evaluate``, ``save``, ``load``
                and the ``act_state`` / ``final_act_state`` attributes.
            rollouts: Stored on the agent and forwarded by ``clone``.
            log_interval: Stored on the agent and forwarded by ``clone``.
        """
        assert isinstance(envs, Learner), "envs of a coord agent must be a Learner"
        self.context = context
        self.args = args
        self.envs = envs
        self.policy = policy
        self.rollouts = rollouts
        self.num_updates = 0      # number of non-eval learn() calls
        self.total_num_steps = 0  # number of step() calls (not inner env steps)
        self.eval_mode = False
        self.log_interval = log_interval
        self.state = None         # latest "state" entry, when args.state is set
        self.final_state = None   # infos[0]["true_state"] from the last step()
        self.act_state = None     # policy act state captured by the last step()
        # Declared here so every attribute exists from construction on; they
        # are filled in by reset() / step().
        self.obs = None
        self.goal = None
        self.tmp_rew = None

    def _unpack_observation(self, raw_obs):
        """Return the plain observation from ``raw_obs``, updating ``self.state``.

        For the env types in ``_DICT_OBS_ENV_TYPES`` the raw value is indexed
        with ``"observation"``; when ``args.state`` is truthy its ``"state"``
        entry is also converted to a tensor and stored in ``self.state``.
        For every other env type ``raw_obs`` is returned unchanged.
        """
        if self.args.env_type in self._DICT_OBS_ENV_TYPES:
            if self.args.state:
                self.state = torch.tensor(raw_obs["state"])
            return raw_obs["observation"]
        return raw_obs

    def reset(self, **kwargs):
        """Reset the wrapped learner and return the initial observation.

        All keyword arguments are forwarded to ``envs.reset``.
        """
        self.obs = self._unpack_observation(self.envs.reset(**kwargs))
        return self.obs

    def step(self, eval=False, render=False, **kwargs):
        """Run one coordinator step made of several inner learner steps.

        The inner loop stops once ``args.max_steps`` inner steps were taken or
        the learner reports ``done``. ``args.max_steps`` must be non-zero:
        with zero inner steps the reward average divides by zero.

        Args:
            eval: Forwarded to ``envs.step``.
            render: Forwarded to ``envs.step``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``(obs, goal, reward, done, infos)`` where ``reward``, ``done`` and
            ``infos`` come from the last inner step and ``goal`` is the most
            recently selected goal.
        """
        done = False
        current_step = 0
        self.tmp_rew = 0
        while current_step != self.args.max_steps and not done:
            replan = current_step == 0 or (
                self.args.plan_interval > 0
                and current_step % self.args.plan_interval == 0
            )
            if replan:
                # Query the policy for a new goal without tracking gradients.
                with torch.no_grad():
                    goal, cluster = self.policy.act(
                        self.obs, step=current_step, state=self.state
                    )
            raw_obs, _, reward, done, infos = self.envs.step(
                goal,
                goal_cluster=cluster,
                step=current_step,
                eval=eval,
                render=render,
                act_state=self.policy.act_state,
            )
            self.obs = self._unpack_observation(raw_obs)
            current_step += 1
            # The reward of a step that ends the episode is masked out.
            self.tmp_rew += reward * (1 - done)
        # Mean masked reward over the inner steps actually executed.
        self.tmp_rew /= current_step
        self.total_num_steps += 1
        self.goal = goal

        last_info = infos[0]
        if last_info["true_obs"] is not None:
            # Prefer the observation reported in infos, with a leading axis added.
            self.obs = last_info["true_obs"].unsqueeze(0)
        if self.args.state:
            self.final_state = torch.tensor(last_info["true_state"])
            final_act_state = self.policy.final_act_state
            self.act_state = (
                final_act_state
                if final_act_state is not None
                else self.policy.act_state
            )

        return self.obs, goal, reward, done, infos

    def learn(self):
        """Feed the outcome of the last :meth:`step` to ``policy.evaluate``.

        Does nothing and returns ``None`` in eval mode; otherwise increments
        ``num_updates`` and returns the policy. Must be called after
        :meth:`step`, which produces the reward, goal and observation used.
        """
        if self.eval_mode:
            return
        self.num_updates += 1
        ret = self.tmp_rew
        self.policy.evaluate(
            ret.detach(),
            self.goal,
            self.obs,
            final_state=self.final_state,
            act_state=self.act_state,
        )

        return self.policy

    def print(self):
        """No-op placeholder."""
        pass

    def can_learn(self):
        """Always report that this agent can learn."""
        return True

    def eval(self):
        """Switch to eval mode, in which :meth:`learn` is a no-op."""
        self.eval_mode = True

    def train(self):
        """Switch back to training mode."""
        self.eval_mode = False

    def after_update(self):
        """No-op placeholder."""
        pass

    def save(self):
        """Save the policy, then the wrapped learner."""
        self.policy.save()
        self.envs.save()

    def load(self):
        """Load the policy, then the wrapped learner."""
        self.policy.load()
        self.envs.load()

    def clone(self, envs, **kwargs):
        """Return a new ``CoordAgent`` around ``envs`` sharing this agent's setup.

        Context, args, policy, rollouts and ``log_interval`` are carried over;
        counters and eval mode start fresh. Extra keyword arguments are
        accepted for interface compatibility and ignored.
        """
        return CoordAgent(
            envs,
            self.context,
            self.args,
            self.policy,
            self.rollouts,
            log_interval=self.log_interval,
        )

